1 /**
2    Support the automatic implementation of test doubles via programmable mocks.
3  */
4 module unit_threaded.mock;
5 
6 import unit_threaded.from;
7 
/// Yields its alias argument unchanged; used to turn arbitrary symbol
/// expressions (such as `__traits` results) into something aliasable.
template Identity(alias T) {
    alias Identity = T;
}

/// True when `member` of `T` cannot be accessed from this module.
private template isPrivate(T, string member) {
    enum isPrivate = !__traits(compiles, __traits(getMember, T, member));
}
10 
/**
   Generates, during CTFE, D source code implementing every mockable
   virtual method of `T`; the result is meant to be mixed into a class
   deriving from `T`.

   For overload number `i` of a method `name` the generated code declares
   `name_i_parameters` / `name_i_returnType` aliases, a `name_i_returnValues`
   array for non-void methods, and an `override` that appends the call to
   `calledFuncs` / `calledValues` and then returns the next programmed
   return value (or `.init` once there is none).

   `const` and `immutable` overloads are skipped because recording a call
   mutates the mock.

   The generated text refers to `Parameters`, `ReturnType`, `tuple`, `to`,
   `calledFuncs` and `calledValues`, which must be visible at the mixin site.

   Returns: the generated code, or `null` when not called during CTFE.
 */
string implMixinStr(T)() {
    import std.array : join;
    import std.format : format;
    import std.traits : functionAttributes, FunctionAttribute, Parameters,
        ReturnType, arity;
    import std.conv : text;

    if (!__ctfe)
        return null;

    string[] lines;

    // Source text naming overload `i` of `memberName`, usable at the mixin site.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`.format(memberName, i);
    }

    foreach (memberName; __traits(allMembers, T)) {

        static if (!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if (__traits(isVirtualMethod, member)) {
                foreach (i, overload; __traits(getOverloads, T, memberName)) {

                    // Everything below must describe *this* overload: `member`
                    // only denotes the first one, so using it would give later
                    // overloads the wrong attributes, return type and arity.
                    enum attrs = functionAttributes!overload;
                    enum isNothrow = (attrs & FunctionAttribute.nothrow_) != 0;

                    // const and immutable overloads can't be mocked: recording
                    // the call mutates the mock's state.
                    static if (!(attrs & FunctionAttribute.const_)
                            && !(attrs & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName,
                                overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName,
                                overloadString);

                        // A nothrow override must wrap the (throwing) recording
                        // code in try/catch, so its body is indented one level.
                        static if (isNothrow)
                            enum tryIndent = "    ";
                        else
                            enum tryIndent = "";

                        static if (is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            // Pop the next programmed value, falling back to .init.
                            enum returnDefault = [
                                    `    if(` ~ varName ~ `.length > 0) {`,
                                    `        auto ret = ` ~ varName ~ `[0];`,
                                    `        ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                    `        return ret;`,
                                    `    } else`,
                                    `        return %s_returnType.init;`.format(overloadName)
                                ];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName
                            ~ typeAndArgsParens!(Parameters!overload)(overloadName) ~ " "
                            ~ functionAttributesString!overload ~ ` {`;

                        static if (isNothrow)
                            lines ~= "try {";

                        lines ~= tryIndent ~ `    calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ `    calledValues ~= tuple`
                            ~ argNamesParens(arity!overload) ~ `.to!string;`;

                        static if (isNothrow)
                            lines ~= "    } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}
95 
/// Parenthesised argument list "(arg0, ..., argN-1)" for generated code.
/// Like the other generators here it only produces text during CTFE and
/// yields null at run time.
private string argNamesParens(int N) @safe pure {
    if (__ctfe)
        return "(" ~ argNames(N) ~ ")";
    return null;
}
101 
/// Comma-separated argument names "arg0, arg1, ..., argN-1" (empty for N <= 0).
/// Only produces text during CTFE; returns null at run time.
private string argNames(int N) @safe pure {
    import std.conv : to;

    if (!__ctfe)
        return null;

    string names;
    foreach (idx; 0 .. N) {
        if (idx > 0)
            names ~= ", ";
        names ~= "arg" ~ idx.to!string;
    }
    return names;
}
111 
/// Builds a parameter list such as "(p_parameters[0] arg0, p_parameters[1] arg1)"
/// with one entry per type in `T`, where `p` is `prefix`.
/// Only produces text during CTFE; returns null at run time.
private string typeAndArgsParens(T...)(string prefix) {
    import std.format : format;

    if (!__ctfe)
        return null;

    string list = "(";
    foreach (idx; 0 .. T.length) {
        if (idx != 0)
            list ~= ", ";
        list ~= format("%s_parameters[%s] arg%s", prefix, idx, idx);
    }
    return list ~ ")";
}
126 
/// Space-separated D keywords for the attributes of `F` that a mock
/// override can reproduce (pure, nothrow, @trusted, @safe, @nogc,
/// @system, shared). Only produces text during CTFE; null at run time.
private string functionAttributesString(alias F)() {
    import std.traits : functionAttributes, FunctionAttribute;
    import std.array : join;

    if (!__ctfe)
        return null;

    static struct Spelling {
        FunctionAttribute flag;
        string keyword;
    }

    // const and immutable are deliberately absent: the mock has to mutate
    // its own state when recording calls.
    enum Spelling[] spellings = [
        Spelling(FunctionAttribute.pure_, "pure"),
        Spelling(FunctionAttribute.nothrow_, "nothrow"),
        Spelling(FunctionAttribute.trusted, "@trusted"),
        Spelling(FunctionAttribute.safe, "@safe"),
        Spelling(FunctionAttribute.nogc, "@nogc"),
        Spelling(FunctionAttribute.system, "@system"),
        Spelling(FunctionAttribute.shared_, "shared"),
    ];

    const attrs = functionAttributes!F;

    string[] keywords;
    foreach (spelling; spellings) {
        if (attrs & spelling.flag)
            keywords ~= spelling.keyword;
    }
    return keywords.join(" ");
}
159 
/**
   State and checks shared by all mocks: the list of expected calls and the
   list of calls actually made, each as a function name plus the arguments
   converted to a string.
 */
mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    /// Records that `funcName` is expected to be called next, optionally
    /// with `values`; with no values, any arguments are accepted.
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv : to;
        import std.typecons : tuple;

        expectedFuncs ~= funcName;
        // An empty expected value means "don't check the arguments".
        static if (V.length == 0)
            expectedValues ~= "";
        else
            expectedValues ~= tuple(values).to!string;
    }

    /// Adds an expectation and verifies immediately, then re-arms the mock so
    /// later verification still runs.
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(
            auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /// Checks every expected call against the call made at the same position.
    /// Calls beyond the expected ones are ignored. Fails if called twice
    /// without re-arming.
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range : repeat, take, join;
        import std.conv : to;
        import unit_threaded.should : fail, UnitTestException;

        if (_verified)
            fail("Mock already _verified", file, line);

        _verified = true;

        foreach (n, expectedFunc; expectedFuncs) {
            auto nth = n.to!string;

            if (n >= calledFuncs.length)
                fail("Expected nth " ~ nth ~ " call to " ~ expectedFunc ~ " did not happen",
                        file, line);

            auto calledFunc = calledFuncs[n];
            if (expectedFunc != calledFunc)
                fail("Expected nth " ~ nth ~ " call to " ~ expectedFunc
                        ~ " but got " ~ calledFunc ~ " instead", file, line);

            auto expectedValue = expectedValues[n];
            auto calledValue = calledValues[n];
            if (expectedValue == "" || expectedValue == calledValue)
                continue;

            auto padding = " ".repeat.take(expectedFunc.length + 4).join;
            throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledValue,
                    padding ~ "instead of the expected " ~ expectedValue], file, line);
        }
    }
}
212 
/// True when the alias argument is a value of type `string`.
private enum bool isString(alias T) = is(typeof(T) == string);
214 
/**
   A mock object that conforms to an interface/class `T`.

   A generated nested class overrides every mockable virtual method of `T`
   to record each call (name and stringified arguments) and to replay values
   programmed with `returnValue`. `alias this` lets the mock be passed
   wherever a `T` is expected and exposes `expect`, `expectCalled` and
   `verify`. If `verify` has not run by the time the mock is destroyed, the
   destructor runs it.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    class MockAbstract : T {
        // Names referenced by the code generated by implMixinStr.
        import std.conv : to;
        import std.traits : Parameters, ReturnType;
        import std.typecons : tuple;

        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    ///
    this(int /* force constructor*/ ) {
        _impl = new MockAbstract;
    }

    /// Verifies the recorded calls unless that has already been done.
    ~this() pure @safe {
        // A default-initialised mock (e.g. `Mock!T.init`) never allocated an
        // implementation; reading `_verified` through a null `_impl` would crash.
        if (_impl !is null && !_verified)
            verify;
    }

    /// Set the returnValue of a function to certain values.
    /// Values are returned one per call; once exhausted, `.init` is returned.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv : text;

        // Matches the `<name>_<i>_returnValues` array declared by implMixinStr.
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach (v; values)
            mixin(varName ~ ` ~=  v;`);
    }

    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}
276 
/// Builds one `import X;` line per module name, in argument order.
/// Only produces text during CTFE; returns null at run time.
private string importsString(string module_, string[] Modules...) {
    if (!__ctfe)
        return null;

    string[] allModules = [module_] ~ Modules;
    string code;
    foreach (name; allModules)
        code ~= `import ` ~ name ~ ";\n";
    return code;
}
287 
/// Creates a ready-to-use `Mock!T`, going through the constructor that
/// allocates the mock's implementation object.
Mock!T mock(T)() {
    return Mock!T(0);
}
292 
293 ///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    // Code under test: calls foo once.
    int doubledFoo(Foo target) {
        return 2 * target.foo(5, "foobar");
    }

    auto fooMock = mock!Foo;
    fooMock.expect!"foo";
    doubledFoo(fooMock);
}
309 
310 ///
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    // Calls foo with exactly the arguments the expectation names.
    int doubledFoo(Foo target) {
        return 2 * target.foo(5, "foobar");
    }

    auto fooMock = mock!Foo;
    fooMock.expect!"foo"(5, "foobar");
    doubledFoo(fooMock);
}
328 
329 ///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int doubledFoo(Foo target) {
        return 2 * target.foo(5, "foobar");
    }

    auto fooMock = mock!Foo;
    // Exercise first, then assert on the call that was made.
    doubledFoo(fooMock);
    fooMock.expectCalled!"foo"(5, "foobar");
}
345 
346 ///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimesN(Foo target) {
        return target.timesN(3) * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.returnValue!"timesN"(42);
    // The programmed 42 is doubled by the code under test.
    assert(doubledTimesN(fooMock) == 84);
}
363 
364 ///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int doubledTimesN(Foo target) {
        return target.timesN(3) * 2;
    }

    auto fooMock = mock!Foo;
    fooMock.returnValue!"timesN"(42, 12);
    // Programmed values are consumed in order; afterwards int.init (0) is used.
    foreach (expected; [84, 24, 0])
        assert(doubledTimesN(fooMock) == expected);
}
382 
/**
   Compile-time association of a function name with the values it should
   return, for the template-parameter overload of `mockStruct`.
   All values are stored as the type of the first one.
 */
struct ReturnValues(string function_, T...)
        if (from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values as a newly allocated array of the first value's type.
    static auto values() {
        alias Element = typeof(T[0]);
        Element[] collected;
        foreach (value; T)
            collected ~= value;
        return collected;
    }
}
396 
/// True when `T` is an instantiation of `ReturnValues`.
enum bool isReturnValue(alias T) = is(T : ReturnValues!U, U...);
/// True when `T` has a type (e.g. a literal or constant) rather than being a type.
enum bool isValue(alias T) = is(typeof(T));
399 
/**
   Version of mockStruct that accepts zero or more values of the same type.
   Whatever function is called on the mock, these values are returned one by
   one, and `.init` once they run out. The limitation is that every function
   called on the mock returns that same type; with no values, calls return
   `void`.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if (T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            auto opDispatch(string funcName, V...)(auto ref V values) {
                import std.conv : to;
                import std.typecons : tuple;

                // Record the call for later verification.
                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if (T.length > 0) {
                    if (_returnValues.length > 0) {
                        auto next = _returnValues[0];
                        _returnValues = _returnValues[1 .. $];
                        return next;
                    }
                    return FirstType.init;
                }
            }
        }
    }

    Mock result;
    result._impl = new Mock.MockImpl;
    static if (T.length > 0) {
        foreach (value; returns)
            result._impl._returnValues ~= value;
    }

    return result;
}
451 
/**
   Version of mockStruct that accepts a compile-time mapping of function
   name to return values. Each template parameter must be a `ReturnValues`
   instantiation.

   Calling a function named in one of them returns its values one by one;
   once they are exhausted, the `.init` value of their type is returned,
   like the value-based overload. Functions not named in any `ReturnValues`
   are still recorded but return `void`.
 */
auto mockStruct(T...)()
        if (T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // Next value index per function name.
        int[string] _retIndices;

        auto opDispatch(string funcName, V...)(auto ref V values) {

            import std.conv : to;
            import std.typecons : tuple;

            calledFuncs ~= funcName;
            calledValues ~= tuple(values).to!string;

            foreach (retVal; T) {
                static if (retVal.funcName == funcName) {
                    auto vals = retVal.values;
                    const index = _retIndices[funcName]++;
                    // Previously indexing past the end raised a RangeError.
                    return index < vals.length ? vals[index] : typeof(vals[0]).init;
                }
            }
        }
    }

    Mock result;

    foreach (retVal; T) {
        result._retIndices[retVal.funcName] = 0;
    }

    return result;
}
494 
495 ///
@("mock struct positive")
@safe pure unittest {
    void callFoobar(T)(T target) {
        target.foobar;
    }

    auto structMock = mockStruct;
    structMock.expect!"foobar";
    callFoobar(structMock);
    structMock.verify;
}
507 
508 ///
@("mock struct values positive")
@safe pure unittest {
    void callFoobar(T)(T target) {
        target.foobar(2, "quux");
    }

    auto structMock = mockStruct;
    // Expectation names both the function and its arguments.
    structMock.expect!"foobar"(2, "quux");
    callFoobar(structMock);
    structMock.verify;
}
520 
521 ///
@("struct return value")
@safe pure unittest {

    int doubledTimesN(T)(T target) {
        return target.timesN(3) * 2;
    }

    auto structMock = mockStruct(42, 12);
    // Values are consumed in order; afterwards int.init (0) is returned.
    foreach (expected; [84, 24, 0])
        assert(doubledTimesN(structMock) == expected);
    structMock.expectCalled!"timesN";
}
535 
536 ///
@("struct expectCalled")
@safe pure unittest {
    void callFoobar(T)(T target) {
        target.foobar(2, "quux");
    }

    auto structMock = mockStruct;
    // Exercise first, then check the recorded call.
    callFoobar(structMock);
    structMock.expectCalled!"foobar"(2, "quux");
}
547 
548 ///
@("mockStruct different return types for different functions")
@safe pure unittest {
    auto structMock = mockStruct!(ReturnValues!("length", 5), ReturnValues!("greet", "hello"));
    // Each function has its own return type and value list.
    assert(structMock.length == 5);
    assert(structMock.greet("bar") == "hello");
    structMock.expectCalled!"length";
    structMock.expectCalled!"greet"("bar");
}
557 
558 ///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    auto structMock = mockStruct!(ReturnValues!("length", 5, 3), ReturnValues!("greet", "hello", "g'day"));

    // Each expectCalled re-arms the mock, so calls can be checked one at a time.
    assert(structMock.length == 5);
    structMock.expectCalled!"length";
    assert(structMock.length == 3);
    structMock.expectCalled!"length";

    assert(structMock.greet("bar") == "hello");
    structMock.expectCalled!"greet"("bar");
    assert(structMock.greet("quux") == "g'day");
    structMock.expectCalled!"greet"("quux");
}
572 
/**
   A mock struct that always throws: calling any method on it throws a new
   `E` with the message "<method> was called" and the caller's file and
   line. `R` is the declared return type of every such call.
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)(
                auto ref V values) {
            enum message = funcName ~ " was called";
            throw new E(message, file, line);
        }
    }

    Mock thrower;
    return thrower;
}
588 
589 ///
@("throwStruct default")
@safe pure unittest {
    import std.exception : assertThrown;
    import unit_threaded.should : UnitTestException;

    auto thrower = throwStruct;
    // Any call, with or without arguments, throws the default exception type.
    assertThrown!UnitTestException(thrower.foo);
    assertThrown!UnitTestException(thrower.bar(1, "foo"));
}